# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" ARN backbone."""

import mindspore.nn as nn
import mindspore.ops as ops

from mindvision.video.ops import MaxPool3D
from mindvision.engine.class_factory import ClassFactory, ModuleType


class C3DEncoder(nn.Cell):
    """Feature encoder for ARN built from four 3D-conv blocks.

    Every block is Conv3d(kernel=3, pad=1) -> BatchNorm3d -> ReLU; the first
    two blocks additionally max-pool with a 2x2x2 window, halving each
    spatio-temporal dimension.

    Returns:
        Tensor, encoded feature map.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels, pool_kernel=None):
        """Build one Conv3d -> BN -> ReLU stack, optionally ending in MaxPool3D."""
        cells = [
            nn.Conv3d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm3d(out_channels, momentum=1, affine=True),
            nn.ReLU(),
        ]
        if pool_kernel is not None:
            cells.append(MaxPool3D(pool_kernel))
        return nn.SequentialCell(cells)

    def __init__(self):
        super(C3DEncoder, self).__init__()
        # Resolution comments assume the original input size; only the first
        # two blocks downsample.
        self.layer1 = self._conv_block(3, 64, pool_kernel=2)   # frame/2 x 64 x 64
        self.layer2 = self._conv_block(64, 64, pool_kernel=2)  # frame/2 x 32 x 32
        self.layer3 = self._conv_block(64, 64)
        self.layer4 = self._conv_block(64, 64)                 # frame/2 x 32 x 32

    def construct(self, x):
        """Apply the four conv blocks in sequence."""
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = layer(x)
        return x


class SpatialAttUnit(nn.Cell):
    """Spatial attention unit that refines aggregation by re-weighting block
    contributions.

    Two Conv3d -> BN -> ReLU stages, each followed by a temporal-only max-pool,
    then a 1x1x1 conv squashed through a sigmoid to produce per-location
    weights in (0, 1).

    Returns:
        Tensor, attention map (1 x 1 x H x W per sample).
    """

    def __init__(self):
        super(SpatialAttUnit, self).__init__()
        attention_cells = [
            nn.Conv3d(64, 16, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm3d(16, momentum=1, affine=True),
            nn.ReLU(),
            MaxPool3D((2, 1, 1)),  # frame/2 x 32 x 32 x 32: halve temporal dim only
            nn.Conv3d(16, 16, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm3d(16, momentum=1, affine=True),
            nn.ReLU(),
            MaxPool3D((2, 1, 1)),
            nn.Conv3d(16, 1, kernel_size=1, padding=0),
            nn.Sigmoid(),  # 1 x 1 x H x W, weights in (0, 1)
        ]
        self.sd1 = nn.SequentialCell(attention_cells)

    def construct(self, x):
        """Compute the spatial attention map for feature tensor ``x``."""
        return self.sd1(x)


@ClassFactory.register(ModuleType.BACKBONE)
class ARNBackbone(nn.Cell):
    """ARN architecture.

    Encodes support and query clips with a shared C3D encoder, applies spatial
    attention, and turns each clip into a power-normalized second-order
    (Gram-matrix) representation.

    Args:
        sigma (int):  Initializer for the sigma weight. Default: 100.
        temporal_dim (int):  Number of temporal dimension. Default: 5.
        jigsaw (int): Number of the output dimension for spacial-temporal jigsaw. Default: 10.
        support_num_per_class (int): Number of samples in support set per class. Default: 5.
        query_num_per_class (int): Number of samples in query set per class. Default: 3.
        class_num (int): Number of classes. Default: 5.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ARNBackbone(100, 5, 10, 5, 3, 5)
    """

    def __init__(self, sigma=100, temporal_dim=5, jigsaw=10, support_num_per_class=5, query_num_per_class=3,
                 class_num=5):
        super(ARNBackbone, self).__init__()
        self.sigma = sigma
        self.temporal_dim = temporal_dim
        self.jigsaw = jigsaw
        self.support_num_per_class = support_num_per_class
        self.query_num_per_class = query_num_per_class
        self.class_num = class_num

        self.feature_encoder = C3DEncoder()
        self.spatial_detector = SpatialAttUnit()

        # transpose_b=True so self.mm(s, s) yields the Gram matrix s @ s.T.
        self.mm = ops.MatMul(transpose_b=True)
        self.sigmoid = ops.Sigmoid()
        self.mean = ops.ReduceMean()
        self.expand = ops.ExpandDims()
        self.transpose = ops.Transpose()
        self.cat_relation = ops.Concat(axis=2)
        self.stack_feature = ops.Stack(axis=0)

    def power_norm(self, x):
        """Signed power normalization: maps x through 2*sigmoid(sigma*x) - 1."""
        out = 2.0 * self.sigmoid(self.sigma * x) - 1.0
        return out

    def construct(self, data):
        """construct of arn backbone.

        ``data`` is expected to be a 6-D episode tensor whose second axis
        stacks the support clips first, then the query clips.
        """
        # BUGFIX: the split was hard-coded as data[0, 0:5] / data[0, 5:],
        # which disagrees with the parameterized reshapes below
        # (support_num_per_class * class_num support samples). Derive the
        # split point from the configured episode size instead.
        num_support = self.support_num_per_class * self.class_num
        support = data[0, 0:num_support, :, :, :, :]
        query = data[0, num_support:, :, :, :, :]

        support_features = self.feature_encoder(support)
        query_features = self.feature_encoder(query)

        # 1 + attention keeps a residual path so un-attended locations are
        # not zeroed out.
        support_ta = 1 + self.spatial_detector(support_features)
        query_ta = 1 + self.spatial_detector(query_features)

        # Flatten spatio-temporal locations: (N, 64, T*32*32).
        support_features = (support_features * support_ta).reshape(num_support, 64,
                                                                   self.temporal_dim * 32 * 32)
        query_features = (query_features * query_ta).reshape(self.query_num_per_class * self.class_num, 64,
                                                             self.temporal_dim * 32 * 32)

        so_support_features = []
        so_query_features = []

        # Second-order (Gram) features per clip, trace-normalized and
        # power-normalized. Iterate over the reshaped tensors so the loop
        # bound always matches the first axis actually indexed.
        for dd in range(support_features.shape[0]):
            s = support_features[dd, :, :].reshape(64, -1)
            s = (1.0 / s.shape[1]) * self.mm(s, s)
            so_support_features.append(self.power_norm(s / s.trace()))
        so_support_features = self.stack_feature(so_support_features)

        for dd in range(query_features.shape[0]):
            t = query_features[dd, :, :].reshape(64, -1)
            t = (1.0 / t.shape[1]) * self.mm(t, t)
            so_query_features.append(self.power_norm(t / t.trace()))
        so_query_features = self.stack_feature(so_query_features)

        # Average support representations within each class; keep a singleton
        # axis so support and query shapes stay broadcast-compatible.
        so_support_features = so_support_features.reshape(self.class_num, self.support_num_per_class, 1, 64, 64).mean(1)
        so_query_features = self.expand(so_query_features, 1)

        return so_support_features, so_query_features
